--- external/stylegan/training/loss.py	2023-04-06 03:45:26.165066348 +0000
+++ external_reference/stylegan/training/loss.py	2023-04-06 03:41:03.250603352 +0000
@@ -14,6 +14,13 @@
 from torch_utils.ops import conv2d_gradfix
 from torch_utils.ops import upfirdn2d
 
+import copy
+import random
+import torch.nn.functional as F
+from utils import camera_util, losses, regularizers
+from external.gsn.models.diff_augment import DiffAugment
+from utils.utils import interpolate
+
 #----------------------------------------------------------------------------
 
 class Loss:
@@ -23,7 +30,11 @@
 #----------------------------------------------------------------------------
 
 class StyleGAN2Loss(Loss):
-    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0):
+    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10,
+                 style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2,
+                 pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0,
+                 blur_fade_kimg=0, training_mode=None, loss_layout_kwargs=None,
+                 loss_upsampler_kwargs=None, loss_sky_kwargs=None):
         super().__init__()
         self.device             = device
         self.G                  = G
@@ -38,18 +49,73 @@
         self.pl_mean            = torch.zeros([], device=device)
         self.blur_init_sigma    = blur_init_sigma
         self.blur_fade_kimg     = blur_fade_kimg
+        self.loss_layout_kwargs = loss_layout_kwargs
+        self.loss_upsampler_kwargs = loss_upsampler_kwargs
+        self.loss_sky_kwargs = loss_sky_kwargs
+        self.training_mode = training_mode
+        self.loss_l1 = losses.L1_Loss().to(device)
+
+    def run_G(self, z, c, camera_params, real_img_masked=None, real_acc=None, update_emas=False):
+        # Generator forward pass dispatched on self.training_mode; returns (img, aux), where aux is
+        # an infos dict ('layout', 'upsampler') or ws ('sky'). Any other mode falls through -> None.
+        assert(self.style_mixing_prob == 0) # style mixing is not implemented in this run_G
+        if self.training_mode == 'layout':
+            img, infos = self.G(z=z, c=c, camera_params=camera_params,
+                                update_emas=update_emas, extras=['opacity_regularization'])
+            return img, infos
+        elif self.training_mode == 'upsampler': # camera_params is not used in this mode
+            upsampler_ws, feature, thumb, extras = self.G.mapping(z, c, update_emas=update_emas)
+            img = self.G.synthesis(upsampler_ws, feature, thumb, extras=extras, update_emas=update_emas)
+            return img, dict(thumb=thumb, img=img, ws=upsampler_ws, extras=extras)
+        elif self.training_mode == 'sky': # taken from anyres 360 loss; camera_params is not used
+            ws = self.G.mapping(z, c, update_emas=update_emas)
+            multiply = True if random.uniform(0, 1) < self.loss_sky_kwargs.mask_prob else False # drawn once per call
+            input_layer = self.G.G.synthesis.input # synthesis input object; supplies fov, frame_size, size
+            crop_start = random.randint(0, 360 // input_layer.fov * input_layer.frame_size[0] - 1) # NOTE(review): bound presumably spans the full 360 grid - confirm
+            crop_fn = lambda grid : grid[:, :, crop_start:crop_start+input_layer.size[0], :]
+            img_base = self.G.synthesis(ws, real_img_masked, real_acc,
+                                        multiply=False, crop_fn=crop_fn, update_emas=update_emas)
+            crop_shift = crop_start + input_layer.frame_size[0]
+            # second crop starts frame_size[0] after the first; shifted frame for cross-frame discriminator
+            crop_fn_shift = lambda grid : grid[:, :, crop_shift:crop_shift+input_layer.size[0], :]
+            img_shifted = self.G.synthesis(ws, real_img_masked, real_acc,
+                                           multiply=False, crop_fn=crop_fn_shift,
+                                           update_emas=update_emas)
+            img_splice = torch.cat([img_base, img_shifted], dim=3) # base and shifted side by side along dim 3
+            img_size = img_base.shape[-1]
+            splice_start = random.randint(0, img_size) # inclusive bounds: 0 -> img_base, img_size -> img_shifted (equal widths assumed)
+            img = img_splice[:, :, :, splice_start:splice_start+img_size]
+            # blend in the real img only after splicing (both synthesis calls above pass multiply=False)
+            if multiply:
+                img = img * (1-real_acc) + real_img_masked * real_acc
+            return img, ws
+
+    def run_D(self, infos, c, blur_sigma=0, update_emas=False):
+        if self.training_mode == 'layout':
+            img = infos['rgb']
+            if self.loss_layout_kwargs.concat_depth:
+                depth = infos['depth']
+                # D_shape = self.D.img_resolution
+                # depth = F.interpolate(depth, size=D_shape, mode='bilinear', align_corners=False)
+                img = torch.cat([img, depth], dim=1)
+            if self.loss_layout_kwargs.concat_acc:
+                acc = infos['acc']
+                # D_shape = self.D.img_resolution
+                # acc = F.interpolate(acc, size=D_shape, mode='bilinear', align_corners=False)
+                img = torch.cat([img, acc], dim=1)
+            if self.loss_layout_kwargs.aug_policy:
+                img = DiffAugment(img, normalize=True, policy=self.loss_layout_kwargs.aug_policy)
+            assert(blur_sigma == 0) # using GSN DiffAugment module
+        elif self.training_mode == 'upsampler':
+            img = infos['img']
+            if self.loss_upsampler_kwargs.d_ignore_depth_acc:
+                # use 0.5 constant as input into depth and acc channels
+                b, c, h, w = img.shape
+                ignore_tensor = torch.ones(b, c-3, h, w).to(img.device) * 0.5
+                img = torch.cat([img[:, :3], ignore_tensor], dim=1)
+        elif self.training_mode == 'sky':
+            img = infos # output of sky generator is img itself
 
-    def run_G(self, z, c, update_emas=False):
-        ws = self.G.mapping(z, c, update_emas=update_emas)
-        if self.style_mixing_prob > 0:
-            with torch.autograd.profiler.record_function('style_mixing'):
-                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
-                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
-                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
-        img = self.G.synthesis(ws, update_emas=update_emas)
-        return img, ws
-
-    def run_D(self, img, c, blur_sigma=0, update_emas=False):
         blur_size = np.floor(blur_sigma * 3)
         if blur_size > 0:
             with torch.autograd.profiler.record_function('blur'):
@@ -57,10 +123,17 @@
                 img = upfirdn2d.filter2d(img, f / f.sum())
         if self.augment_pipe is not None:
             img = self.augment_pipe(img)
+
+        if self.D.recon:
+            assert(self.training_mode == 'layout')
+            # handle return elements with reconstruction discriminator
+            logits, recon = self.D(img, c, update_emas=update_emas)
+            return logits, img, recon
+
         logits = self.D(img, c, update_emas=update_emas)
         return logits
 
-    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
+    def accumulate_gradients(self, phase, real_img_infos, real_c, gen_z, gen_c, gain, cur_nimg):
         assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
         if self.pl_weight == 0:
             phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
@@ -68,15 +141,72 @@
             phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
         blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
 
+        camera_params = copy.deepcopy(real_img_infos['camera_params'])
+        del camera_params['Rt']
+        real_img = real_img_infos['rgb'] # img with sky masked
+        real_orig = real_img_infos['orig'] # img with sky 
+        real_depth = real_img_infos['depth']
+        real_acc = real_img_infos['acc']
+
         # Gmain: Maximize logits for generated images.
         if phase in ['Gmain', 'Gboth']:
             with torch.autograd.profiler.record_function('Gmain_forward'):
-                gen_img, _gen_ws = self.run_G(gen_z, gen_c)
-                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
-                training_stats.report('Loss/scores/fake', gen_logits)
-                training_stats.report('Loss/signs/fake', gen_logits.sign())
-                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
-                training_stats.report('Loss/G/loss', loss_Gmain)
+                if self.training_mode == 'layout':
+                    gen_img, gen_infos = self.run_G(gen_z, gen_c, camera_params)
+                    gen_logits, _, _ = self.run_D(gen_infos, gen_c, blur_sigma=blur_sigma)
+                    training_stats.report('Loss/scores/fake', gen_logits)
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
+                    loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
+                    training_stats.report('Loss/G/loss', loss_Gmain)
+                    if self.loss_layout_kwargs.lambda_finite_difference > 0:
+                        if self.loss_layout_kwargs.lambda_ramp_end > 0:
+                            ramp_multiplier = cur_nimg / self.loss_layout_kwargs.lambda_ramp_end
+                        else:
+                            ramp_multiplier = 1
+                        ramp_multiplier = np.clip(ramp_multiplier, 0, 1)
+                        training_stats.report('Loss/G/reg_ramp', ramp_multiplier)
+                        finite_diff = regularizers.ray_finite_difference(gen_infos['extra_outputs'])
+                        loss_Gmain = loss_Gmain + ramp_multiplier * self.loss_layout_kwargs.lambda_finite_difference * finite_diff
+                        training_stats.report('Loss/G/reg_ray', finite_diff)
+                    training_stats.report('Loss/G/loss_total', loss_Gmain)
+                elif self.training_mode == 'upsampler':
+                    gen_img, infos = self.run_G(gen_z, gen_c, camera_params)
+                    thumb = infos['thumb']
+                    gen_logits = self.run_D(infos, gen_c, blur_sigma=blur_sigma)
+                    training_stats.report('Loss/scores/fake', gen_logits)
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
+                    loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
+                    training_stats.report('Loss/G/loss', loss_Gmain)
+                    if self.loss_upsampler_kwargs.lambda_rec > 0:
+                        # downsample gen img, it should match thumb
+                        ch = thumb.shape[1]
+                        gen_down = interpolate(gen_img, thumb.shape[-2:])
+                        l1_loss = self.loss_l1(gen_down, thumb)
+                        loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_rec * l1_loss)
+                        training_stats.report('Loss/G/loss_rec_l1', l1_loss)
+                    if self.loss_upsampler_kwargs.lambda_up > 0:
+                        # upsample the depth and acc maps
+                        depth_and_acc_thumb_up = interpolate(thumb[:, 3:], gen_img.shape[-2:])
+                        depth_and_acc_gen = gen_img[:, 3:]
+                        l1_loss = self.loss_l1(depth_and_acc_gen, depth_and_acc_thumb_up)
+                        loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_up * l1_loss)
+                        training_stats.report('Loss/G/loss_up_l1', l1_loss)
+                    if self.loss_upsampler_kwargs.lambda_gray_pixel > 0:
+                        acc_mask = (gen_img[:, -1:] > 0.5).float().detach()
+                        pixel_sum = torch.sum(torch.abs(gen_img[:, :3]), dim=1, keepdim=True)
+                        penalty = torch.exp(-self.loss_upsampler_kwargs.lambda_gray_pixel_falloff * pixel_sum) * acc_mask
+                        reg_loss = torch.mean(penalty, dim=(1, 2, 3))
+                        loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_gray_pixel * reg_loss)
+                        training_stats.report('Loss/G/loss_reg_gray', reg_loss)
+                    training_stats.report('Loss/G/loss_total', loss_Gmain)
+                elif self.training_mode == 'sky':
+                    gen_img, _gen_ws = self.run_G(gen_z, gen_c, camera_params, real_img, real_acc)
+                    gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
+                    training_stats.report('Loss/scores/fake', gen_logits)
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
+                    loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
+                    training_stats.report('Loss/G/loss', loss_Gmain)
+
             with torch.autograd.profiler.record_function('Gmain_backward'):
                 loss_Gmain.mean().mul(gain).backward()
 
@@ -84,7 +214,24 @@
         if phase in ['Greg', 'Gboth']:
             with torch.autograd.profiler.record_function('Gpl_forward'):
                 batch_size = gen_z.shape[0] // self.pl_batch_shrink
-                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
+                camera_params_batch = {k: v[:batch_size] for k, v in camera_params.items()}
+                if self.training_mode == 'layout':
+                    gen_img, gen_infos = self.run_G(gen_z[:batch_size],
+                                                    gen_c[:batch_size],
+                                                    camera_params_batch)
+                    gen_ws = gen_infos['ws']
+                if self.training_mode == 'upsampler':
+                    gen_img, infos = self.run_G(gen_z[:batch_size],
+                                                gen_c[:batch_size],
+                                                camera_params_batch)
+                    gen_ws = infos['ws']
+                if self.training_mode == 'sky':
+                    gen_img, gen_ws = self.run_G(gen_z[:batch_size],
+                                                 gen_c[:batch_size],
+                                                 camera_params_batch,
+                                                 real_img[:batch_size],
+                                                 real_acc[:batch_size])
+                # gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
                 pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                 with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
                     pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
@@ -102,11 +249,24 @@
         loss_Dgen = 0
         if phase in ['Dmain', 'Dboth']:
             with torch.autograd.profiler.record_function('Dgen_forward'):
-                gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True)
-                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
-                training_stats.report('Loss/scores/fake', gen_logits)
-                training_stats.report('Loss/signs/fake', gen_logits.sign())
-                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
+                if self.training_mode == 'layout':
+                    gen_img, gen_infos = self.run_G(gen_z, gen_c, camera_params, update_emas=True)
+                    gen_logits, _, _ = self.run_D(gen_infos, gen_c, blur_sigma=blur_sigma, update_emas=True)
+                    training_stats.report('Loss/scores/fake', gen_logits)
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
+                    loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
+                elif self.training_mode == 'upsampler':
+                    gen_img, infos = self.run_G(gen_z, gen_c, camera_params, update_emas=True)
+                    gen_logits = self.run_D(infos, gen_c, blur_sigma=blur_sigma, update_emas=True)
+                    training_stats.report('Loss/scores/fake', gen_logits)
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
+                    loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
+                elif self.training_mode == 'sky':
+                    gen_img, _gen_ws = self.run_G(gen_z, gen_c, camera_params, real_img, real_acc, update_emas=True)
+                    gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
+                    training_stats.report('Loss/scores/fake', gen_logits)
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
+                    loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
             with torch.autograd.profiler.record_function('Dgen_backward'):
                 loss_Dgen.mean().mul(gain).backward()
 
@@ -115,8 +275,29 @@
         if phase in ['Dmain', 'Dreg', 'Dboth']:
             name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
             with torch.autograd.profiler.record_function(name + '_forward'):
-                real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
-                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
+                if self.training_mode == 'layout':
+                    real_img_tmp = {
+                        'rgb': real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']),
+                        'depth': real_depth.detach(), # .requires_grad_(phase in ['Dreg', 'Dboth']),
+                        'acc': real_acc.detach(), # .requires_grad_(phase in ['Dreg', 'Dboth']),
+                    }
+                    r1_grads_input = real_img_tmp['rgb']
+                    real_logits, real_disc_in, real_recon = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
+
+                elif self.training_mode == 'upsampler':
+                    real_input = real_img
+                    if self.G.upsampler.synthesis.num_additional_feature_channels > 0:
+                        real_input = torch.cat([real_input, real_depth], dim=1)
+                    if self.G.upsampler.synthesis.num_additional_feature_channels > 1:
+                        real_input = torch.cat([real_input, real_acc], dim=1)
+                    real_img_tmp = real_input.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
+                    r1_grads_input = real_img_tmp
+                    infos_D = dict(img=real_img_tmp)
+                    real_logits = self.run_D(infos_D, real_c, blur_sigma=blur_sigma)
+                elif self.training_mode == 'sky':
+                    real_img_tmp = real_orig.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
+                    r1_grads_input = real_img_tmp
+                    real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                 training_stats.report('Loss/scores/real', real_logits)
                 training_stats.report('Loss/signs/real', real_logits.sign())
 
@@ -128,13 +309,20 @@
                 loss_Dr1 = 0
                 if phase in ['Dreg', 'Dboth']:
                     with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
-                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
+                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[r1_grads_input], create_graph=True, only_inputs=True)[0]
                     r1_penalty = r1_grads.square().sum([1,2,3])
                     loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                     training_stats.report('Loss/r1_penalty', r1_penalty)
                     training_stats.report('Loss/D/reg', loss_Dr1)
 
+                loss_Drecon = 0
+                if self.D.recon:
+                    assert(self.training_mode == 'layout')
+                    if phase in ['Dmain', 'Dboth']:
+                        loss_Drecon = F.mse_loss(real_disc_in, real_recon) * self.loss_layout_kwargs.recon_weight
+                        training_stats.report('Loss/D/recon_loss', loss_Drecon)
+
             with torch.autograd.profiler.record_function(name + '_backward'):
-                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()
+                (loss_Dreal + loss_Dr1 + loss_Drecon).mean().mul(gain).backward()
 
 #----------------------------------------------------------------------------
